package org.jvm.rtda.heap.classmember;

import org.jvm.classfile.MemberInfo;
import org.jvm.classfile.attributeinfo.*;
import org.jvm.rtda.heap.*;
import org.jvm.rtda.heap.classLoader.KlassLoaderRegister;
import org.jvm.rtda.heap.symref.ClassRef;
import org.jvm.rtda.thread.Thread;
import org.jvm.util.JvmUtil;

import java.util.*;
import java.util.stream.Collectors;

/**
 * 类方法信息
 *
 * @author 王思翔
 * @date 2023/2/4
 */
public class Method extends ClassMember {

    /** Reserved opcode 0xfe (impdep1); injected as the INVOKE_NATIVE instruction of native method bodies. */
    private static final byte OP_INVOKE_NATIVE = (byte) 0xfe;
    /** ireturn: used for int, boolean, byte, char and short return types. */
    private static final byte OP_IRETURN = (byte) 0xac;
    /** lreturn: long return type. */
    private static final byte OP_LRETURN = (byte) 0xad;
    /** freturn: float return type. */
    private static final byte OP_FRETURN = (byte) 0xae;
    /** dreturn: double return type. */
    private static final byte OP_DRETURN = (byte) 0xaf;
    /** areturn: reference (object or array) return type. */
    private static final byte OP_ARETURN = (byte) 0xb0;
    /** return: void return type. */
    private static final byte OP_RETURN = (byte) 0xb1;

    private int maxStack;

    private int maxLocals;

    private byte[] code;

    /** Number of local-variable slots occupied by the arguments (including {@code this} for instance methods). */
    private int argSlotCount;

    private MethodDescriptor methodDescriptor;

    private ExceptionHandler[] exceptionTable = new ExceptionHandler[0];

    private LineNumberTableAttribute lineNumberTableAttribute;

    private byte[] parameterAnnotationData = new byte[0];

    private byte[] annotationDefaultData = new byte[0];

    public int getMaxStack() {
        return maxStack;
    }

    public int getMaxLocals() {
        return maxLocals;
    }

    public int getArgSlotCount() {
        return argSlotCount;
    }

    public byte[] getCode() {
        return code;
    }

    public ExceptionHandler[] getExceptionTable() {
        return exceptionTable;
    }

    public Method() {
        super();
    }

    public void setCode(byte[] code) {
        this.code = code;
    }

    /**
     * Builds the runtime method from its class-file {@code method_info}.
     *
     * <p>Native methods (and the proxied {@code URLClassLoader.defineClass}) get a synthetic
     * two-instruction body; methods without a Code attribute (e.g. abstract or interface methods)
     * keep empty code. Everything else copies max_stack, max_locals, code, the exception table and
     * the LineNumberTable from the Code attribute.
     *
     * @param klass      the class declaring this method
     * @param memberInfo the parsed {@code method_info} structure
     */
    public Method(Klass klass, MemberInfo memberInfo) {
        super(klass, memberInfo);
        // Also parses the descriptor into this.methodDescriptor, which injectCodeAttribute relies on.
        this.argSlotCount = calArgSlotCount();
        // These attributes belong to method_info itself, not to the Code attribute, so they must be
        // read before the early returns below: annotation-interface elements are abstract (no Code
        // attribute) and are precisely the methods that carry an AnnotationDefault attribute.
        this.parameterAnnotationData = memberInfo.getRuntimeVisibleParameterAnnotationsAttributeData();
        this.annotationDefaultData = memberInfo.getAnnotationDefaultAttributeData();
        if (isNative() || isDefineClassProxy()) {
            // No real bytecode to run: inject INVOKE_NATIVE + return and size the frame by hand.
            injectCodeAttribute();
            return;
        }
        // Abstract and interface methods may have no Code attribute at all.
        CodeAttribute codeAttribute = memberInfo.getCodeAttribute();
        if (codeAttribute == null) {
            return;
        }
        this.maxStack = codeAttribute.getMaxStack();
        this.maxLocals = codeAttribute.getMaxLocals();
        this.code = codeAttribute.getCode();
        this.exceptionTable = parseExceptionTable(codeAttribute.getExceptionTable());
        this.lineNumberTableAttribute = (LineNumberTableAttribute) Arrays.stream(codeAttribute.getAttributes())
                .filter(attributeInfo -> attributeInfo instanceof LineNumberTableAttribute)
                .findAny().orElse(null);
    }

    /**
     * Whether this is {@code URLClassLoader.defineClass(String, Resource)}, which is routed through the
     * native-method mechanism instead of executing its bytecode.
     */
    private boolean isDefineClassProxy() {
        return "defineClass".equals(this.getName())
                && "java/net/URLClassLoader".equals(this.getKlass().getName())
                && "(Ljava/lang/String;Lsun/misc/Resource;)Ljava/lang/Class;".equals(this.getDescriptor());
    }

    /**
     * Finds the bytecode offset of the handler in this method that can catch the given exception at
     * the given pc. Entries are scanned in exception-table order and the first match wins.
     *
     * @param exceptionClass class of the exception being thrown
     * @param pc             bytecode offset at which the exception occurred
     * @param thread         thread used to resolve catch-type class references
     * @return the handler's start offset, or -1 if no entry of this method applies
     */
    public int findExceptionHandler(Klass exceptionClass, int pc, Thread thread) {
        for (ExceptionHandler handler : this.exceptionTable) {
            // Check the [startPc, endPc) range before resolving anything, so catch classes of
            // handlers that cannot apply are never resolved (and possibly loaded).
            if (pc < handler.getStartPc() || pc >= handler.getEndPc()) {
                continue;
            }
            // A null catch type (catch_type == 0, e.g. finally blocks) matches every exception.
            Klass catchClass = Optional.ofNullable(handler.getCatchType())
                    .map(classRef -> classRef.resolvedClass(thread))
                    .orElse(null);
            if (catchClass == null || catchClass == exceptionClass || catchClass.isSuperClassOf(exceptionClass)) {
                return handler.getHandlerPc();
            }
        }
        return -1;
    }

    /** Whether this is an instance initializer ({@code <init>}). */
    public boolean isConstructor() {
        return !this.isStatic() && "<init>".equals(this.name);
    }

    /** Whether this is the static class initializer ({@code <clinit>}). */
    public boolean isClinit() {
        return this.isStatic() && "<clinit>".equals(this.name);
    }

    /**
     * Converts the class-file exception table into runtime handlers, looking up each catch type in
     * the declaring class's constant pool.
     */
    private ExceptionHandler[] parseExceptionTable(ExceptionTableEntry[] exceptionTable) {
        ExceptionHandler[] res = new ExceptionHandler[exceptionTable.length];
        for (int i = 0; i < exceptionTable.length; i++) {
            ExceptionTableEntry exceptionTableEntry = exceptionTable[i];
            // catch_type == 0 means "catch everything"; the constant lookup is expected to yield null
            // then (NOTE(review): relies on getConstant(0) returning null — confirm in ConstantPool).
            ClassRef exceptionClass = (ClassRef) this.klass.getConstantPool().getConstant(exceptionTableEntry.getCatchType());
            res[i] = new ExceptionHandler(exceptionTableEntry.getStartPc(),
                    exceptionTableEntry.getEndPc(), exceptionTableEntry.getHandlerPc(), exceptionClass);
        }
        return res;
    }

    /**
     * Gives a native method a synthetic body: INVOKE_NATIVE followed by the return instruction that
     * matches the descriptor's return type, with a hand-chosen frame size.
     */
    private void injectCodeAttribute() {
        // Fixed operand stack depth of 4 for now.
        this.maxStack = 4;
        // Native methods never touch locals via JVM instructions, so the argument slots suffice.
        this.maxLocals = this.argSlotCount;
        byte returnCode;
        switch (methodDescriptor.getReturnType().charAt(0)) {
            case 'V':
                returnCode = OP_RETURN;
                break;
            case 'D':
                returnCode = OP_DRETURN;
                break;
            case 'F':
                returnCode = OP_FRETURN;
                break;
            case 'J':
                returnCode = OP_LRETURN;
                break;
            case 'L':
            case '[':
                returnCode = OP_ARETURN;
                break;
            default:
                // Z, B, C, S, I are all returned as int.
                returnCode = OP_IRETURN;
        }
        this.code = new byte[]{OP_INVOKE_NATIVE, returnCode};
    }

    /**
     * Parses this method's descriptor (storing it in {@code methodDescriptor}) and computes how many
     * local-variable slots its arguments occupy: long and double take two slots, every other type
     * one, and instance methods take one extra slot for {@code this}.
     *
     * @return the number of argument slots
     */
    private int calArgSlotCount() {
        MethodDescriptor descriptor = MethodDescriptorParser.parseMethodDescriptor(this.descriptor);
        this.methodDescriptor = descriptor;
        int slots = descriptor.getParameterTypes().stream()
                .mapToInt(item -> item.equals("J") || item.equals("D") ? 2 : 1)
                .sum();
        if (!isStatic()) {
            slots++;
        }
        return slots;
    }

    public Boolean isAbstract() {
        return (this.accessFlags & AccessFlags.ACC_ABSTRACT.getCode()) != 0;
    }

    public Boolean isNative() {
        return (this.accessFlags & AccessFlags.ACC_NATIVE.getCode()) != 0;
    }

    @Override
    public String toString() {
        return "Method{" +
                "class='" + this.getKlass().getName() + '\'' +
                ", name='" + name + '\'' +
                ", descriptor='" + descriptor + '\'' +
                '}';
    }

    /**
     * Maps a bytecode offset to its source line number.
     *
     * @param pc bytecode offset within this method
     * @return the line number; -2 for native methods (the same convention as
     *         {@code StackTraceElement}); -1 when there is no LineNumberTable
     */
    public int getLineNumber(int pc) {
        if (isNative()) {
            return -2;
        }
        if (this.lineNumberTableAttribute == null) {
            return -1;
        }
        return this.lineNumberTableAttribute.getLineNumber(pc);
    }

    /**
     * Loads the classes of this method's declared parameter types, in declaration order.
     *
     * @param thread thread used for class loading
     * @return one class per parameter
     */
    public Klass[] getParameterTypes(Thread thread) {
        return this.methodDescriptor.getParameterTypes()
                .stream()
                .map(param -> JvmUtil.toClassName(param))
                .map(name -> KlassLoaderRegister.loadKlass(name, thread))
                .toArray(Klass[]::new);
    }

    /** Loads the class of this method's return type. */
    public Klass getReturnType(Thread thread) {
        String className = JvmUtil.toClassName(this.methodDescriptor.getReturnType());
        return KlassLoaderRegister.loadKlass(className, thread);
    }

    /** Returns the raw return-type descriptor (e.g. {@code V}, {@code I}, {@code Ljava/lang/String;}). */
    public String getReturnTypeDescriptor() {
        return this.methodDescriptor.getReturnType();
    }

    /**
     * Resolves the catch types of this method's exception-table entries.
     * Catch-all entries (catch_type == 0, e.g. finally blocks) have no type and are skipped.
     *
     * <p>NOTE(review): {@code java.lang.reflect.Method#getExceptionTypes} reports the declared
     * {@code throws} clause (Exceptions attribute), not catch types — confirm this is intended.
     *
     * @param thread thread used to resolve the class references
     * @return resolved catch classes, in exception-table order
     */
    public Klass[] getExceptionTypes(Thread thread) {
        if (this.exceptionTable == null) {
            return new Klass[0];
        }
        return Arrays.stream(this.exceptionTable)
                .map(ExceptionHandler::getCatchType)
                .filter(Objects::nonNull)
                .map(e -> e.resolvedClass(thread))
                .toArray(Klass[]::new);
    }

    public byte[] getParameterAnnotationData() {
        return parameterAnnotationData;
    }

    public byte[] getAnnotationDefaultData() {
        return annotationDefaultData;
    }
}
